package cn.plugins.generator.utils

import cn.plugins.generator.entity.ColumnEntity
import cn.plugins.generator.entity.DatabaseInfoEntity
import cn.plugins.generator.entity.TableEntity
import com.alibaba.druid.pool.DruidDataSourceFactory
import java.sql.*
import javax.sql.DataSource


/**
 * Created 2020-01-08 15:48
 *
 * @author jiangyun
 *
 * Database connection utility: reads table and column metadata of a MySQL schema
 * from `information_schema`.
 *
 * One pooled [Connection] is cached and reused while the connection settings
 * (url / user name / password) stay the same and the connection is still open;
 * a call with different settings releases the previous pool and opens a new one.
 * NOTE(review): the cache is not synchronized — assumes callers use this object from one thread at a time.
 */
object MysqlDatabaseUtil {
    /** Druid pool the cached [connection] was borrowed from; kept so it can be closed when settings change. */
    private var dataSource: DataSource? = null

    /** Cached connection shared by all metadata queries. */
    private var connection: Connection? = null

    /** The (url, userName, password) the cached [connection] was opened with. */
    private var connectionKey: List<Any?>? = null

    /**
     * Returns a pooled connection for [databaseInfo].
     *
     * The cached connection is reused only when it was opened with the same url, user name
     * and password and has not been closed. Otherwise the old pool is released and a new
     * Druid pool is built from [PropertiesUtil.getMysqlConnectionProperties].
     */
    private fun getConnection(databaseInfo: DatabaseInfoEntity): Connection {
        val key = listOf(databaseInfo.url, databaseInfo.userName, databaseInfo.password)
        val cached = connection
        if (cached != null && key == connectionKey && !cached.isClosed) {
            return cached
        }

        // Settings changed (or the connection died): drop the old pool before creating a new one.
        releaseConnection()
        val properties = PropertiesUtil.getMysqlConnectionProperties(databaseInfo.url, databaseInfo.userName, databaseInfo.password)
        val pool: DataSource = DruidDataSourceFactory.createDataSource(properties)
        dataSource = pool
        val fresh = pool.getConnection()
        connection = fresh
        connectionKey = key
        return fresh
    }

    /**
     * Closes the cached connection and its pool (best effort) and clears the cache.
     * Close failures are printed rather than thrown because the resources are being discarded anyway.
     */
    private fun releaseConnection() {
        try {
            connection?.close()
        } catch (e: SQLException) {
            e.printStackTrace()
        }
        try {
            (dataSource as? AutoCloseable)?.close()
        } catch (e: Exception) {
            e.printStackTrace()
        }
        connection = null
        dataSource = null
        connectionKey = null
    }

    /**
     * Runs [sql] as a prepared statement, binding [params] as strings to the `?` placeholders
     * in order, and hands the result set to [mapper]. Statement and result set are always closed.
     */
    private inline fun <T> query(conn: Connection, sql: String, vararg params: String, mapper: (ResultSet) -> T): T =
        conn.prepareStatement(sql).use { ps ->
            params.forEachIndexed { index, value -> ps.setString(index + 1, value) }
            ps.executeQuery().use(mapper)
        }

    /** Returns [fallback] when the receiver is null or blank, otherwise the receiver itself. */
    private fun String?.orIfBlank(fallback: String): String = this?.takeUnless { it.isBlank() } ?: fallback

    /**
     * Drops the last two characters of [tableName] when its second-to-last character is `表`;
     * otherwise (including names shorter than two characters) returns it unchanged.
     * NOTE(review): this mirrors the original logic; it may have been meant as `removeSuffix("表")` — confirm.
     */
    private fun stripTableSuffix(tableName: String): String =
        if (tableName.length >= 2 && tableName[tableName.length - 2] == '表') {
            tableName.substring(0, tableName.length - 2)
        } else {
            tableName
        }

    /**
     * Gets the information of a single table.
     *
     * @param url           connection URL
     * @param databaseName  database (schema) name
     * @param userName      user name
     * @param password      password
     * @param tableName     table name
     * @return the table info; when the table does not exist an empty [TableEntity] is returned.
     *         A trailing `表` is stripped from the comment, and a blank comment becomes "暂无表备注".
     */
    @Throws(Exception::class)
    fun getTableInfo(
        url: String,
        databaseName: String,
        userName: String,
        password: String,
        tableName: String
    ): TableEntity? {
        val conn = getConnection(DatabaseInfoEntity(url, userName, password))
        val sql = """
            SELECT TABLE_NAME,
                   table_comment COMMENT
            FROM information_schema.tables
            WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?
        """.trimIndent()

        return query(conn, sql, tableName, databaseName) { rs ->
            val tableInfo = TableEntity()
            if (rs.next()) {
                tableInfo.tableName = stripTableSuffix(tableName)
                // Null-safe: the original dereferenced the comment before its null check.
                val comment = rs.getString("comment")?.removeSuffix("表")
                tableInfo.comment = comment.orIfBlank("暂无表备注")
            }
            tableInfo
        }
    }

    /**
     * Gets the information of every table in a schema (excluding `flyway_schema_history`),
     * ordered by table name.
     *
     * @param url           connection URL
     * @param databaseName  database (schema) name
     * @param userName      user name
     * @param password      password
     * @return one [TableEntity] per table
     */
    @Throws(Exception::class)
    fun getAllTableInfo(
        url: String,
        databaseName: String,
        userName: String,
        password: String
    ): MutableList<TableEntity>? {
        val conn = getConnection(DatabaseInfoEntity(url, userName, password))
        val sql = """
            SELECT table_name,
                   table_comment,
                   table_collation,
                   engine,
                   create_options
            FROM information_schema.tables
            WHERE TABLE_SCHEMA = ? AND table_name != 'flyway_schema_history'
            ORDER BY TABLE_NAME
        """.trimIndent()

        return query(conn, sql, databaseName) { rs ->
            val tables = mutableListOf<TableEntity>()
            while (rs.next()) {
                tables += readTable(rs)
            }
            tables
        }
    }

    /** Maps the current row of the `getAllTableInfo` query to a [TableEntity]. */
    private fun readTable(rs: ResultSet): TableEntity {
        val tableInfo = TableEntity()
        val collation = rs.getString("table_collation")
        tableInfo.tableName = rs.getString("table_name")
        tableInfo.comment = rs.getString("table_comment").orIfBlank("暂无表备注")
        tableInfo.collation = collation
        tableInfo.engine = rs.getString("engine")
        // Charset is the collation prefix (e.g. "utf8mb4" from "utf8mb4_general_ci");
        // a NULL collation falls back to "utf8" instead of throwing.
        tableInfo.defaultCharset = collation?.substringBefore("_").orIfBlank("utf8")
        tableInfo.rowFormat = rs.getString("create_options").orIfBlank("").toUpperCase()
        return tableInfo
    }

    /**
     * Gets the column information of a table, in ordinal order.
     *
     * @param url           connection URL
     * @param databaseName  database (schema) name
     * @param userName      user name
     * @param password      password
     * @param tableName     table name
     * @return one [ColumnEntity] per column (empty when the table does not exist)
     */
    @Throws(Exception::class)
    fun findTableColumns(
        url: String,
        databaseName: String,
        userName: String,
        password: String,
        tableName: String
    ): MutableList<ColumnEntity>? {
        val conn = getConnection(DatabaseInfoEntity(url, userName, password))
        val sql = """
            SELECT column_name              AS columnName,
                   data_type                AS dataType,
                   column_default           AS columnDefault,
                   is_nullable              AS isNullable,
                   column_type              AS columnType,
                   extra,
                   character_maximum_length AS characterMaximumLength,
                   column_comment           AS columnComment,
                   column_key               AS columnKey
            FROM information_schema.columns
            WHERE table_name = ?
              AND table_schema = ?
            ORDER BY ordinal_position
        """.trimIndent()

        return query(conn, sql, tableName, databaseName) { rs ->
            val columns = mutableListOf<ColumnEntity>()
            while (rs.next()) {
                columns += readColumn(rs)
            }
            columns
        }
    }

    /** Maps the current row of the `findTableColumns` query to a [ColumnEntity]. */
    private fun readColumn(rs: ResultSet): ColumnEntity {
        val column = ColumnEntity()
        val dataType = rs.getString("dataType")
        val columnDefault = rs.getString("columnDefault")
        // A missing nullability flag is treated as nullable.
        val isNullable = rs.getString("isNullable").orIfBlank("YES")

        column.columnName = rs.getString("columnName")
        column.dataType = dataType
        column.columnDefault = columnDefault
        column.isNullable = isNullable
        column.nullFlag = isNullable

        // Default value shown in the generated markdown docs for string columns.
        if (dataType == "char" || dataType == "varchar") {
            when {
                isNullable == "YES" -> column.defaultValue = "NULL"
                isNullable == "NO" && columnDefault == "" -> column.defaultValue = "Empty String"
            }
        }

        column.characterMaximumLength = rs.getString("characterMaximumLength").orIfBlank("")
        column.columnComment = rs.getString("columnComment").orIfBlank("")
        column.columnKey = rs.getString("columnKey").orIfBlank("")
        column.columnType = rs.getString("columnType")
        column.extra = rs.getString("extra").orIfBlank("").toUpperCase()
        return column
    }
}